{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G3MMAcssHTML"
      },
      "source": [
        "<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n",
        "<link rel=\"stylesheet\" href=\"https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,GRAD@20..48,100..700,0..1,-50..200\" />"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FUOiKRSF7jc1"
      },
      "source": [
        "# Inference with RecurrentGemma using JAX and Flax"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "60KmTK7o6ppd"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/recurrentgemma/recurrentgemma_jax_inference\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/recurrentgemma/recurrentgemma_jax_inference.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/google/generative-ai-docs/main/site/en/gemma/docs/recurrentgemma/recurrentgemma_jax_inference.ipynb\"><img src=\"https://ai.google.dev/images/cloud-icon.svg\" width=\"40\" />Open in Vertex AI</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/recurrentgemma/recurrentgemma_jax_inference.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tdlq6K0znh3O"
      },
      "source": [
        "This tutorial demonstrates how to perform basic sampling/inference with the [RecurrentGemma](https://ai.google.dev/gemma/docs/recurrentgemma) 2B Instruct model using [Google DeepMind's `recurrentgemma` library](https://github.com/google-deepmind/recurrentgemma), which is written with [JAX](https://jax.readthedocs.io) (a high-performance numerical computing library), [Flax](https://flax.readthedocs.io) (the JAX-based neural network library), [Orbax](https://orbax.readthedocs.io/) (a JAX-based library for training utilities like checkpointing), and [SentencePiece](https://github.com/google/sentencepiece) (a tokenizer/detokenizer library). Although this notebook does not use Flax directly, Flax was used to create Gemma and RecurrentGemma (the Griffin model).\n",
        "\n",
        "This notebook can run on Google Colab with the T4 GPU (go to **Edit** > **Notebook settings** > Under **Hardware accelerator** select **T4 GPU**)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aKvTsIkL98BG"
      },
      "source": [
        "## Setup\n",
        "\n",
        "The following sections explain the steps for preparing a notebook to use a RecurrentGemma model, including model access, getting an API key, and configuring the notebook runtime."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WCgCkmQSPxkE"
      },
      "source": [
        "### Set up Kaggle access for Gemma\n",
        "\n",
        "To complete this tutorial, first follow setup instructions similar to those in [Gemma setup](https://ai.google.dev/gemma/docs/setup), with a few exceptions:\n",
        "\n",
        "* Get access to RecurrentGemma (instead of Gemma) on [kaggle.com](https://www.kaggle.com/models/google/recurrentgemma).\n",
        "* Select a Colab runtime with sufficient resources to run the RecurrentGemma model.\n",
        "* Generate and configure a Kaggle username and API key.\n",
        "\n",
        "After you've completed the RecurrentGemma setup, move on to the next section, where you'll set environment variables for your Colab environment.\n",
        "\n",
        "### Set environment variables\n",
        "\n",
        "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the \"Grant access?\" messages, agree to provide secret access."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lKoW-nhE-gNO"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata # `userdata` is a Colab API.\n",
        "\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
      ]
    },
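    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Outside Colab, `google.colab.userdata` is not available. As a minimal sketch, assuming you have already exported `KAGGLE_USERNAME` and `KAGGLE_KEY` in your shell, you could check for them before downloading anything. The helper name `require_kaggle_credentials` is hypothetical and not part of any library.\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "\n",
        "def require_kaggle_credentials(env=None):\n",
        "    # Hypothetical helper: returns (username, key), or raises if either\n",
        "    # variable is missing or empty.\n",
        "    env = os.environ if env is None else env\n",
        "    names = ('KAGGLE_USERNAME', 'KAGGLE_KEY')\n",
        "    missing = [name for name in names if not env.get(name)]\n",
        "    if missing:\n",
        "        raise RuntimeError('Missing Kaggle credentials: ' + ', '.join(missing))\n",
        "    return env['KAGGLE_USERNAME'], env['KAGGLE_KEY']\n",
        "```\n",
        "\n",
        "Calling `require_kaggle_credentials()` with no arguments reads from `os.environ`; passing a dictionary makes the check easy to try without touching the real environment."
      ]
    },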
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AO7a1Q4Yyc9Z"
      },
      "source": [
        "### Install the `recurrentgemma` library\n",
        "\n",
        "This notebook focuses on using a free Colab GPU. To enable hardware acceleration, click on **Edit** > **Notebook settings** > Select **T4 GPU** > **Save**.\n",
        "\n",
        "Next, you need to install the Google DeepMind `recurrentgemma` library from [`github.com/google-deepmind/recurrentgemma`](https://github.com/google-deepmind/recurrentgemma). If you get an error about \"pip's dependency resolver\", you can usually ignore it.\n",
        "\n",
        "**Note:** By installing `recurrentgemma`, you will also install [`flax`](https://flax.readthedocs.io), core [`jax`](https://jax.readthedocs.io), [`optax`](https://optax.readthedocs.io/en/latest/) (the JAX-based gradient processing and optimization library), [`orbax`](https://orbax.readthedocs.io/), and [`sentencepiece`](https://github.com/google/sentencepiece)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WWEzVJR4Fx9g"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting git+https://github.com/google-deepmind/recurrentgemma.git\n",
            "  Cloning https://github.com/google-deepmind/recurrentgemma.git to /tmp/pip-req-build-zz9xp6s4\n",
            "  Running command git clone --filter=blob:none --quiet https://github.com/google-deepmind/recurrentgemma.git /tmp/pip-req-build-zz9xp6s4\n",
            "  Resolved https://github.com/google-deepmind/recurrentgemma.git to commit e4939f9b7edf8baa1d512fb86bfc2e206044d66b\n",
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: absl-py<1.5.0,>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from recurrentgemma==0.1.0) (1.4.0)\n",
            "Collecting einops<0.8.0,>=0.7.0 (from recurrentgemma==0.1.0)\n",
            "  Downloading einops-0.7.0-py3-none-any.whl (44 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.6/44.6 kB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting jaxtyping<0.3.0,>=0.2.28 (from recurrentgemma==0.1.0)\n",
            "  Downloading jaxtyping-0.2.28-py3-none-any.whl (40 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.7/40.7 kB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: numpy<2.0,>=1.21 in /usr/local/lib/python3.10/dist-packages (from recurrentgemma==0.1.0) (1.25.2)\n",
            "Collecting sentencepiece<0.3.0,>=0.2.0 (from recurrentgemma==0.1.0)\n",
            "  Downloading sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting typeguard==2.13.3 (from jaxtyping<0.3.0,>=0.2.28->recurrentgemma==0.1.0)\n",
            "  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)\n",
            "Building wheels for collected packages: recurrentgemma\n",
            "  Building wheel for recurrentgemma (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for recurrentgemma: filename=recurrentgemma-0.1.0-py3-none-any.whl size=73547 sha256=e3d3e85d59877ec33d2e4dff1a1666eaed1342c68199255cdd806d74472d4524\n",
            "  Stored in directory: /tmp/pip-ephem-wheel-cache-42qdygtw/wheels/31/37/18/c57f1df6091b661385ab728b959bdfbf2078d9fc7c856899e4\n",
            "Successfully built recurrentgemma\n",
            "Installing collected packages: sentencepiece, typeguard, einops, jaxtyping, recurrentgemma\n",
            "  Attempting uninstall: sentencepiece\n",
            "    Found existing installation: sentencepiece 0.1.99\n",
            "    Uninstalling sentencepiece-0.1.99:\n",
            "      Successfully uninstalled sentencepiece-0.1.99\n",
            "Successfully installed einops-0.7.0 jaxtyping-0.2.28 recurrentgemma-0.1.0 sentencepiece-0.2.0 typeguard-2.13.3\n"
          ]
        }
      ],
      "source": [
        "!pip install git+https://github.com/google-deepmind/recurrentgemma.git"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VKLjBAe1m3Ck"
      },
      "source": [
        "## Load and prepare the RecurrentGemma model\n",
        "\n",
        "1. Load the RecurrentGemma model with [`kagglehub.model_download`](https://github.com/Kaggle/kagglehub/blob/bddefc718182282882b72f814d407d89e5d178c4/src/kagglehub/models.py#L12), which takes three arguments:\n",
        "\n",
        "- `handle`: The model handle from Kaggle\n",
        "- `path`: (Optional string) The local path\n",
        "- `force_download`: (Optional boolean) Forces a re-download of the model\n",
        "\n",
        "**Note:** Be mindful that the `recurrentgemma-2b-it` model is around 3.85 GB in size."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_W3FUd9lt8VT"
      },
      "outputs": [],
      "source": [
        "RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:\"string\"}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kFCmWEKdMA_Y"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...\n",
            "100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]\n",
            "Extracting model files...\n"
          ]
        }
      ],
      "source": [
        "import kagglehub\n",
        "\n",
        "RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nYmYTMk8aELi"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1\n"
          ]
        }
      ],
      "source": [
        "print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ytNi47xSlw71"
      },
      "source": [
        "**Note:** The path from the output above is where the model weights and tokenizer are saved locally. You will need them later."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "92BcvYdemXbd"
      },
      "source": [
        "2. Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer file will be in the main directory where you downloaded the model, while the model weights will be in a subdirectory. For example:\n",
        "\n",
        "- The `tokenizer.model` file will be in `/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1`.\n",
        "- The model checkpoint will be in `/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QY6OnASOpZbW"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it\n",
            "TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model\n"
          ]
        }
      ],
      "source": [
        "CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)\n",
        "TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')\n",
        "print('CKPT_PATH:', CKPT_PATH)\n",
        "print('TOKENIZER_PATH:', TOKENIZER_PATH)"
      ]
    },
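    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before loading anything, it can help to confirm that both paths actually exist. The following is a minimal sketch using only the standard library. It assumes the checkpoint is stored as a directory and the tokenizer as a single file, and `check_model_paths` is a hypothetical helper, not part of `recurrentgemma` or `kagglehub`.\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "\n",
        "def check_model_paths(ckpt_path, tokenizer_path):\n",
        "    # Hypothetical helper: returns a list of problems, empty if all is well.\n",
        "    problems = []\n",
        "    if not os.path.isdir(ckpt_path):\n",
        "        problems.append('checkpoint directory not found: ' + ckpt_path)\n",
        "    if not os.path.isfile(tokenizer_path):\n",
        "        problems.append('tokenizer file not found: ' + tokenizer_path)\n",
        "    return problems\n",
        "```\n",
        "\n",
        "In this notebook you would call `check_model_paths(CKPT_PATH, TOKENIZER_PATH)` and expect an empty list."
      ]
    },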
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jc0ZzYIW0TSN"
      },
      "source": [
        "## Perform sampling/inference"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aEe3p8geqekV"
      },
      "source": [
        "1. Load the RecurrentGemma model checkpoint with the [`recurrentgemma.jax.load_parameters`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/utils.py#L31) method. Setting the `sharding` argument to `\"single_device\"` loads all model parameters on a single device."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Mnr52JQVqKRw"
      },
      "outputs": [],
      "source": [
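        "# Alias the JAX implementation as `recurrentgemma`, so that later cells can\n",
        "# call `recurrentgemma.load_parameters`, `recurrentgemma.Griffin`, and so on.\n",
        "from recurrentgemma import jax as recurrentgemma\n",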
        "\n",
        "params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding=\"single_device\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-Xpnb2igrGjk"
      },
      "source": [
        "2. Load the RecurrentGemma model tokenizer, constructed using [`sentencepiece.SentencePieceProcessor`](https://github.com/google/sentencepiece/blob/4d6a1f41069c4636c51a5590f7578a0dbed83450/python/src/sentencepiece/__init__.py#L423):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-T0ZHff5rNSy"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "True"
            ]
          },
          "execution_count": 9,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import sentencepiece as spm\n",
        "\n",
        "vocab = spm.SentencePieceProcessor()\n",
        "vocab.Load(TOKENIZER_PATH)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IkAf4fkNrY-3"
      },
      "source": [
        "3. To automatically load the correct configuration from the RecurrentGemma model checkpoint, use [`recurrentgemma.GriffinConfig.from_flax_params_or_variables`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/common.py#L128). Then, instantiate the [Griffin](https://arxiv.org/abs/2402.19427) model with [`recurrentgemma.jax.Griffin`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/griffin.py#L29)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4PNWxDhvrRXJ"
      },
      "outputs": [],
      "source": [
        "model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(\n",
        "    flax_params_or_variables=params)\n",
        "\n",
        "model = recurrentgemma.Griffin(model_config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vs0vgmXVroBq"
      },
      "source": [
        "4. Create a `sampler` with [`recurrentgemma.jax.Sampler`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/sampler.py#L74) on top of the RecurrentGemma model checkpoint/weights and the tokenizer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4GX4pFP6rtyN"
      },
      "outputs": [],
      "source": [
        "sampler = recurrentgemma.Sampler(\n",
        "    model=model,\n",
        "    vocab=vocab,\n",
        "    params=params,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V9yU99Xxr59w"
      },
      "source": [
        "5. Write a prompt in `prompt` and perform inference. You can tweak `total_generation_steps`, the number of steps performed when generating a response. This example uses `50` to preserve host memory.\n",
        "\n",
        "**Note:** If you run out of memory, click on **Runtime** > **Disconnect and delete runtime**, and then **Runtime** > **Run all**."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gj9jRFI5Hrv2"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).\n",
            "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.\n",
            "  warnings.warn(\"Some donated buffers were not usable:\"\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Prompt:\n",
            "\n",
            "# 5+9=?\n",
            "Output:\n",
            "\n",
            "\n",
            "# Answer: 14\n",
            "\n",
            "# Explanation: 5 + 9 = 14.\n"
          ]
        }
      ],
      "source": [
        "prompt = [\n",
        "    \"\\n# 5+9=?\",\n",
        "]\n",
        "\n",
        "reply = sampler(input_strings=prompt,\n",
        "                total_generation_steps=50,\n",
        "                )\n",
        "\n",
        "for input_string, out_string in zip(prompt, reply.text):\n",
        "    print(f\"Prompt:\\n{input_string}\\nOutput:\\n{out_string}\")"
      ]
    },
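    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The prompt above is passed to the model as plain text. Instruction-tuned Gemma models are usually prompted with turn markers, so you may get more consistent replies by wrapping the message first. The sketch below assumes that the `2b-it` variant follows the same `<start_of_turn>` / `<end_of_turn>` convention as the Gemma instruction-tuned models; check the model card before relying on it. `format_chat_prompt` is a hypothetical helper, not part of the `recurrentgemma` library.\n",
        "\n",
        "```python\n",
        "def format_chat_prompt(user_message):\n",
        "    # Wrap a single user message in Gemma-style turn markers and leave the\n",
        "    # model turn open, so that generation continues from there.\n",
        "    return (\n",
        "        '<start_of_turn>user\\n'\n",
        "        + user_message\n",
        "        + '<end_of_turn>\\n'\n",
        "        + '<start_of_turn>model\\n'\n",
        "    )\n",
        "\n",
        "\n",
        "prompt = [format_chat_prompt('5+9=?')]\n",
        "```\n",
        "\n",
        "The resulting `prompt` list can be passed to `sampler(input_strings=prompt, total_generation_steps=50)` in the same way as the plain-text prompt."
      ]
    },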
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bzKsCGIN0yX5"
      },
      "source": [
        "## Learn more\n",
        "\n",
        "- You can learn more about the Google DeepMind [`recurrentgemma` library on GitHub](https://github.com/google-deepmind/recurrentgemma), which contains docstrings of methods and modules you used in this tutorial, such as [`recurrentgemma.jax.load_parameters`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/utils.py#L31), [`recurrentgemma.jax.Griffin`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/griffin.py#L29), and [`recurrentgemma.jax.Sampler`](https://github.com/google-deepmind/recurrentgemma/blob/e4939f9b7edf8baa1d512fb86bfc2e206044d66b/recurrentgemma/jax/sampler.py#L74).\n",
        "- The following libraries have their own documentation sites: [core JAX](https://jax.readthedocs.io), [Flax](https://flax.readthedocs.io), and [Orbax](https://orbax.readthedocs.io/).\n",
        "- For `sentencepiece` tokenizer/detokenizer documentation, check out [Google's `sentencepiece` GitHub repo](https://github.com/google/sentencepiece).\n",
        "- For `kagglehub` documentation, check out `README.md` on [Kaggle's `kagglehub` GitHub repo](https://github.com/Kaggle/kagglehub).\n",
        "- Learn how to [use Gemma models with Google Cloud Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n",
        "- Check out the [RecurrentGemma: Moving Past Transformers for Efficient Open Language Models](https://arxiv.org/pdf/2404.07839) paper by Google DeepMind.\n",
        "- Read the [Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models](https://arxiv.org/pdf/2402.19427) paper by Google DeepMind to learn more about the model architecture used by RecurrentGemma."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "Tce3stUlHN0L"
      ],
      "name": "recurrentgemma_jax_inference.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
